
[BugFix][Mamba] Fix da_cumsum kernel to support dt_bias, softplus, and clamp#1118

Merged
Ibuki-wind merged 5 commits into tile-ai:main from stelladuyx:kernel-fix on May 9, 2026

Conversation

Collaborator

@stelladuyx stelladuyx commented Apr 29, 2026

Summary

da_cumsum

Accepts raw dt and applies the full pipeline as compile-time-conditional steps:

  • Add per-head dt_bias (has_dt_bias=True)
  • Softplus with overflow bypass at dt > 20 (dt_softplus=True)
  • Clamp to [dt_min, dt_max]
  • Inclusive prefix sum of dA = dt_out * A

Returns two outputs: dt_out (the processed dt) and dA_cumsum. The kernel
signature is fixed regardless of flags; unused inputs are dummy-zeroed at the
op boundary only when has_dt_bias=False. Calling forward with
dt_bias=None when has_dt_bias=True now raises ValueError immediately
instead of silently computing unbiased results.
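The pipeline above can be sketched as a pure-Python reference (illustrative only; the single-head scalar layout and function name are assumptions, not the kernel's actual signature):

```python
import math

def da_cumsum_ref(dt, A, dt_bias=None, dt_softplus=True,
                  dt_min=0.0, dt_max=float("inf")):
    """Illustrative single-head reference for the da_cumsum pipeline.

    dt: raw per-timestep deltas for one head; A: scalar decay for that head.
    Returns (dt_out, dA_cumsum), mirroring the kernel's two outputs.
    """
    dt_out, da_cumsum = [], []
    running = 0.0
    for t in dt:
        if dt_bias is not None:
            t = t + dt_bias                   # per-head bias
        if dt_softplus and t <= 20.0:         # overflow bypass: softplus(t) ~= t for t > 20
            t = math.log1p(math.exp(t))       # softplus
        t = min(max(t, dt_min), dt_max)       # clamp to [dt_min, dt_max]
        running += t * A                      # inclusive prefix sum of dA = dt_out * A
        dt_out.append(t)
        da_cumsum.append(running)
    return dt_out, da_cumsum
```

Note the step order (bias, then softplus, then clamp) follows the bullet list above.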

ssd_chunk_scan

Tensor layouts updated to match the official _chunk_scan_fwd:

  • x, C, out changed from chunk-fused [B,C,L,H,P] to seqlen-fused [B,S,H,P]
  • cb changed from head-owned [B,C,H,L,L] to group-owned [B,C,G,L,L]
  • prev_states axis order changed from [B,C,H,N,P] to [B,C,H,P,N] (P before N, official convention)
  • dt layout changed from [B,C,L,H] to [B,H,C,L]
  • n_groups added as a constructor parameter

dA_l shared-memory load moved to just before it is consumed (after the
history path), eliminating a redundant sync_threads stall.
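A minimal numpy sketch of the new layouts (the sizes and the contiguous-heads-per-group mapping are illustrative assumptions; H must be divisible by G):

```python
import numpy as np

B, C, L, H, P, N, G = 2, 3, 4, 8, 5, 6, 2   # illustrative sizes

# x, C, out: chunk-fused [B, C, L, H, P] -> seqlen-fused [B, S, H, P], S = C * L
x_chunked = np.zeros((B, C, L, H, P), dtype=np.float32)
x_seq = x_chunked.reshape(B, C * L, H, P)

# cb: group-owned [B, C, G, L, L]; head h reads its group's [L, L] block
cb = np.zeros((B, C, G, L, L), dtype=np.float32)
def group_of_head(h: int) -> int:
    return h // (H // G)          # assumed mapping: contiguous heads per group

# prev_states: [B, C, H, N, P] -> [B, C, H, P, N] (P before N, official convention)
prev = np.zeros((B, C, H, N, P), dtype=np.float32)
prev_official = prev.transpose(0, 1, 2, 4, 3)

# dt: [B, C, L, H] -> [B, H, C, L]
dt = np.zeros((B, C, L, H), dtype=np.float32)
dt_new = dt.transpose(0, 3, 1, 2)
```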

ssd_state_passing

Fixed output convention to match Mamba-2 spec: out[:,c] now holds the state
before chunk c, so out[:,0] = initial_states and out[:,c+1] = s_c for
c in [0, C-2]. Reference implementations in tests and benchmarks updated
to match.
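The convention can be stated as a tiny reference (pure-Python sketch; the list-of-states layout and function name are assumptions for illustration):

```python
def state_passing_ref(chunk_states, initial_state):
    """out[c] = state before chunk c, per the Mamba-2 convention above:
    out[0] = initial_state, out[c + 1] = chunk_states[c] for c in [0, C - 2];
    the final chunk's state s_{C-1} is not stored in out."""
    C = len(chunk_states)
    return [initial_state] + list(chunk_states[: C - 1])
```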

All Mamba kernels

Added @functools.lru_cache to all five kernel factory functions
(da_cumsum, ssd_chunk_scan, ssd_chunk_state, ssd_decode,
ssd_state_passing) to prevent redundant TileLang recompilation on repeated
calls with identical static parameters.
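The caching pattern looks like the sketch below (the factory name, parameters, and counter are hypothetical; it only illustrates that compilation runs once per unique static-parameter tuple, which also requires every cached parameter to be hashable):

```python
import functools

COMPILE_COUNT = 0   # stand-in for "number of TileLang compilations"

@functools.lru_cache(maxsize=None)
def make_da_cumsum_kernel(has_dt_bias: bool, dt_softplus: bool, block_l: int):
    """Hypothetical factory: the expensive build runs once per unique
    tuple of static parameters; later calls return the cached kernel."""
    global COMPILE_COUNT
    COMPILE_COUNT += 1
    return ("kernel", has_dt_bias, dt_softplus, block_l)  # stand-in for a compiled kernel

k1 = make_da_cumsum_kernel(True, True, 64)
k2 = make_da_cumsum_kernel(True, True, 64)   # cache hit: no recompilation
k3 = make_da_cumsum_kernel(False, True, 64)  # new static params: compiles again
```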

Test plan

  • Run tests/ops/test_mamba.py — all existing and new test cases pass
    (includes new smoke test test_da_cumsum_fwd_missing_bias_raises)
  • Run benchmarks/ops/bench_mamba.py — benchmark executes cleanly with no
    regressions

stelladuyx and others added 2 commits April 29, 2026 15:56
…lamp; update chunk_state, tests, and benchmarks

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…-flag test cases

- Remove is_cuda/dtype/has_dt_bias validation from DaCumsumFwdKernel.forward — per
  project pattern, validation belongs only in Op.forward (system boundary)
- Add bias-only (True,False) and softplus-only (False,True) smoke cases to
  DaCumsumFwdFixture; each guards an independent compile-time branch in the kernel

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@stelladuyx stelladuyx requested a review from a team April 29, 2026 08:47
@github-actions github-actions bot added the fix label (Auto-created by issue labeler) on Apr 29, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the da_cumsum operator and kernel to support optional per-head bias, softplus activation, and clamping, aligning the implementation with the Mamba-2 Triton reference. Additionally, the ssd_chunk_state kernel is optimized by reordering the fused axis decoding to improve L2 cache reuse and by eliminating intermediate register fragments for scaled inputs, which allows for larger tile configurations. Feedback was provided regarding the da_cumsum kernel's autotune configuration, noting that since the scan is implemented serially within blocks, using multiple threads leads to redundant work and hardware contention.

Comment thread tileops/kernels/mamba/da_cumsum.py
The inner scan is T.serial(Q): every thread in the block executes the
same loop and writes to the same output locations. Configs with threads
> 1 produce redundant work and write contention with no throughput
benefit. Remove {32, 64, 128} from the autotune search space.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@stelladuyx stelladuyx self-assigned this May 1, 2026
Contributor

@Ibuki-wind Ibuki-wind left a comment


Overall

One correctness blocker remains; address the inline comment and rerun the affected Mamba tests.

Comment thread tileops/kernels/mamba/da_cumsum.py
@stelladuyx stelladuyx requested a review from Ibuki-wind May 9, 2026 03:11
…e configs

- Swap state_tile load loop from (nn,pp) to (pp,nn) so consecutive threads
  iterate over the contiguous N dimension, giving coalesced 128-byte global
  loads instead of strided-by-N accesses.
- Expand autotune_configs: block_n [16,32] -> [32,64,128],
  block_s [32,64] -> [64,128] to cover larger tile sizes used by H200.
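The effect of the (nn, pp) to (pp, nn) swap can be seen by listing the flat addresses the loop touches in a row-major [P, N] tile (a toy sketch; real coalescing happens across threads, but the address pattern is the same idea):

```python
N, P = 4, 3                 # toy tile; N is the contiguous (unit-stride) axis
stride_n, stride_p = 1, N   # row-major [P, N] layout

# before: inner loop over pp strides by N -> scattered addresses
before = [nn * stride_n + pp * stride_p for nn in range(N) for pp in range(P)]
# after: inner loop over nn is unit-stride -> sequential, coalescing-friendly
after = [nn * stride_n + pp * stride_p for pp in range(P) for nn in range(N)]
```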
@Ibuki-wind Ibuki-wind merged commit 2f6dc23 into tile-ai:main May 9, 2026
11 checks passed
